Update eagle notebook example with sglang #316
Conversation
Walkthrough
The notebook replaces an FP8/quantization-focused speculative decoding example with an EAGLE3-based workflow: it switches to meta-llama/Llama-3.2-1B, prepares data from Daring-Anteater, attaches an EAGLE draft head via mtsp.convert, updates training/export, and adds TRT-LLM and SGLang deployment steps.

Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant Notebook
participant HF as HuggingFace Model
participant MT as mtsp.convert
participant DS as Dataset (/tmp/Daring-Anteater)
participant Trainer
participant FS as Filesystem
User->>Notebook: execute cells
Notebook->>HF: load base model & tokenizer (Llama-3.2-1B)
Notebook->>MT: mtsp.convert(model, [("eagle", cfg)]) -> attach EAGLE draft head
Notebook->>DS: load /tmp/Daring-Anteater/train.jsonl
Notebook->>Trainer: train (num_train_epochs=4)
Trainer-->>Notebook: trained weights
Notebook->>FS: export_hf_checkpoint -> /tmp/hf_ckpt
sequenceDiagram
autonumber
actor Client
participant TRT as TRT-LLM Server
participant SG as SGLang Server
participant DH as EAGLE Draft Head
participant BH as Base Model
rect rgba(200,230,255,0.18)
note over TRT: TRT-LLM speculative flow
Client->>TRT: Chat/completion request
TRT->>DH: request draft tokens
DH-->>TRT: draft + acceptance info
TRT->>BH: verify/complete tokens
BH-->>TRT: final tokens
TRT-->>Client: response
end
rect rgba(220,255,220,0.18)
note over SG: SGLang flow (similar)
Client->>SG: Chat request
SG->>DH: draft generation
SG->>BH: verify/complete
BH-->>SG: final tokens
SG-->>Client: response
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/example.ipynb (1)
114-121: Only train the EAGLE head; freeze the target model. As-is, Trainer will update the full model, which is likely unintended/expensive.
Apply before creating Trainer:

+# Freeze base/target model; train only the EAGLE head
+for name, p in model.named_parameters():
+    p.requires_grad = ("eagle" in name.lower())

If modelopt offers a helper (e.g., mtsp.freeze_base_model/enable_train("eagle")), prefer that.
🧹 Nitpick comments (8)
examples/speculative_decoding/example.ipynb (8)
7-8: Call out prerequisites (HF auth + git-lfs) up front. Daring-Anteater via git and Llama-3.2-1B are likely gated; readers may need HF tokens and git-lfs. Add a short "Prereqs" note here to reduce setup failures.
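For example, a minimal "Prereqs" cell could run a quick check like the sketch below (not code from the notebook; it assumes huggingface_hub is installed and that git-lfs is invoked as `git-lfs` on PATH):

# Hypothetical prerequisites check: verify git-lfs and Hugging Face auth before any downloads
import shutil
from huggingface_hub import whoami

if shutil.which("git-lfs") is None:
    print("git-lfs not found; install it before cloning Daring-Anteater via git")
try:
    print("Hugging Face login:", whoami()["name"])  # uses the cached token, if any
except Exception:
    print("Not logged in to Hugging Face; run `huggingface-cli login` or set HF_TOKEN")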
43-47: HF gated model note. meta-llama/Llama-3.2-1B typically requires license acceptance and an HF token. Consider documenting HF auth/env var usage before this cell to prevent download errors.
62-63: Capture/verify convert() return semantics. Some modelopt APIs return a new module; others mutate in place. Assign defensively or confirm it's in-place to avoid silently training the wrong module.
Apply:

-mtsp.convert(model, [("eagle", config)])
+_ret = mtsp.convert(model, [("eagle", config)])
+if _ret is not None:
+    model = _ret
108-113: Set a seed for reproducibility. Add a deterministic seed to make the demo repeatable.
Apply before creating Trainer:
+transformers.set_seed(42)
318-320: Make cleanup idempotent.
Apply:

-%%sh
-docker rm -f trtllm_serve_spec
+%%sh
+docker rm -f trtllm_serve_spec || true
375-387: Propagate HF cache/token into SGLang container too (mirror TRT-LLM). You already mount the cache; optionally plumb the token as an env var for gated models.
Apply:

-    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    # "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",
459-461: Idempotent cleanup.
Apply:

-%%sh
-docker rm -f sglang_serve_spec
+%%sh
+docker rm -f sglang_serve_spec || true
480-480: Kernel name nit. “modelopt+serve” is a clearer display_name since this notebook demos TRT-LLM and SGLang, not vLLM (yet).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/speculative_decoding/example.ipynb (12 hunks)
🔇 Additional comments (5)
examples/speculative_decoding/example.ipynb (5)
147-154: LGTM: HF checkpoint export step is clear.
336-349: LGTM: SGLang launch script looks reasonable. If multi-GPU is common for your users, consider exposing tensor-parallel flags in the script.
494-494: Python 3.12 support confirmed (as of 2025-09-14). PyTorch, Transformers, NVIDIA TensorRT-LLM, and SGLang official docs/releases list Python 3.12 support; no changes required.
183-197: Speculative config — set decoding_type to "Eagle"; keep autotuner off; enable block reuse if supported.
- decoding_type must be "Eagle" (capital E).
- enable_autotuner is experimental — do not enable by default in production; use offline tuning and validate before enabling.
- kv_cache_config.enable_block_reuse is recommended for production (improves TTFT) only if the model was built with paged‑context/FMHA support (validate model build and behavior before enabling).
File: examples/speculative_decoding/example.ipynb (lines 183–197)
93-93: Verify eagle_utils classes & pad-token masking. Confirm examples/speculative_decoding/eagle_utils.py defines DataCollatorWithPadding (found ~line 207) and LazySupervisedDataset, and verify DataCollatorWithPadding sets labels for pad tokens to -100 (i.e., masks positions where input_ids == pad_token_id, including the case pad_token_id == eos_token_id).
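For reference, the masking behavior being verified could be sketched as follows (illustrative only; mask_pad_labels is a hypothetical helper, not the eagle_utils implementation):

# Sketch of the expected collator behavior: label positions holding pad tokens are
# set to -100 so the loss ignores them. When pad_token_id == eos_token_id, this
# also matches genuine EOS positions, so the real collator should rely on actual
# sequence lengths rather than token identity alone.
import torch

def mask_pad_labels(input_ids: torch.Tensor, labels: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    labels = labels.clone()
    labels[input_ids == pad_token_id] = -100
    return labels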
"%%sh\n", | ||
"git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater" | ||
] |
🛠️ Refactor suggestion
Avoid git clone for datasets; use huggingface_hub to dodge git‑lfs pitfalls.
git‑lfs often isn’t installed in notebook environments; snapshot_download is more reliable and cache‑friendly.
Apply this diff in the cell:
-%%sh
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from huggingface_hub import snapshot_download
+snapshot_download(
+ repo_id="nvidia/Daring-Anteater",
+ repo_type="dataset",
+ local_dir="/tmp/Daring-Anteater",
+ local_dir_use_symlinks=False
+)
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 17 to 19, the
notebook currently runs a shell git clone which can fail due to missing git‑lfs;
replace the shell git clone with a Python-based download using
huggingface_hub.snapshot_download to fetch the dataset to /tmp (or a configured
cache dir). Import huggingface_hub, call
snapshot_download(repo_id="nvidia/Daring-Anteater",
cache_dir="/tmp/Daring-Anteater" or allow default cache), and use the returned
path in subsequent cells; ensure to handle authentication/token if required and
add a short note to pip-install huggingface_hub if not present.
"# Read Default Config for EAGLE3\n", | ||
"config = EAGLE3_DEFAULT_CFG[\"config\"]\n", | ||
"\n", | ||
"# Hidden size and vocab size must match base model\n", | ||
"config[\"eagle_architecture_config\"].update(\n", | ||
" {\n", | ||
" \"hidden_size\": model.config.hidden_size,\n", | ||
" \"vocab_size\": model.config.vocab_size,\n", | ||
" \"draft_vocab_size\": model.config.vocab_size,\n", | ||
" \"max_position_embeddings\": model.config.max_position_embeddings,\n", | ||
" }\n", | ||
")\n", | ||
"\n", |
Do not mutate the global default EAGLE3 config.
Updating EAGLE3_DEFAULT_CFG["config"] in place can affect subsequent runs/calls. Deep‑copy before mutation.
Apply:
-from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
+from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
+import copy
@@
-# Read Default Config for EAGLE3
-config = EAGLE3_DEFAULT_CFG["config"]
+# Read Default Config for EAGLE3 (copy to avoid global mutation)
+config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 49 to 61, the code
updates EAGLE3_DEFAULT_CFG["config"] in place which mutates the global default;
instead, create a deep copy of EAGLE3_DEFAULT_CFG["config"] (or of
EAGLE3_DEFAULT_CFG) into a local variable and perform the update on that copy,
then use the copied config for subsequent initialization so the global default
remains unchanged.
"# Prepare Tokenizer\n", | ||
"tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)\n", | ||
"tokenizer.pad_token_id = tokenizer.eos_token_id\n", | ||
"if tokenizer.chat_template is None:\n", | ||
" tokenizer.chat_template = (\n", | ||
" \"{%- for message in messages %}\"\n", | ||
" \"{{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}\"\n", | ||
" \"{%- endfor %}\"\n", | ||
" )" |
🛠️ Refactor suggestion
Align tokenizer limits to model; set padding side.
1024 may truncate unnecessarily if the base model supports more. Also set padding_side explicitly to avoid surprises in collators/loss masks.
Apply:
-tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)
+tokenizer = transformers.AutoTokenizer.from_pretrained(
+ base_model,
+ model_max_length=getattr(model.config, "max_position_embeddings", 2048),
+)
+tokenizer.padding_side = "right"
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 65–73, replace the
hardcoded model_max_length=1024 with the model's actual limit and explicitly set
padding_side: load the model config (AutoConfig.from_pretrained(base_model)) and
use its max_position_embeddings (or tokenizer.model_max_length if already
provided) to set tokenizer.model_max_length, then set tokenizer.padding_side =
"right" (or "left" if your training expects left padding) alongside
tokenizer.pad_token_id and keep the chat_template logic unchanged.
"with open(\"/tmp/Daring-Anteater/train.jsonl\") as f:\n", | ||
" data_json = [json.loads(line) for line in f]\n", | ||
"train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)\n", | ||
"eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)\n", |
🛠️ Refactor suggestion
Avoid loading the entire JSONL into memory.
Stream or use datasets to handle large files robustly.
Apply:
-with open("/tmp/Daring-Anteater/train.jsonl") as f:
- data_json = [json.loads(line) for line in f]
-train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
-eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
+import json
+def stream_jsonl(path):
+ with open(path) as f:
+ for line in f:
+ yield json.loads(line)
+all_data = list(stream_jsonl("/tmp/Daring-Anteater/train.jsonl"))
+split = int(len(all_data) * 0.95)
+train_dataset = LazySupervisedDataset(all_data[:split], tokenizer=tokenizer)
+eval_dataset = LazySupervisedDataset(all_data[split:], tokenizer=tokenizer)
Or swap to datasets.load_dataset("json", data_files=...).
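A datasets-based variant might look like the following sketch (it assumes, as with data_json above, that LazySupervisedDataset accepts a list of record dicts and that the notebook's tokenizer is in scope):

# Alternative sketch using the datasets library instead of manual JSONL parsing
from datasets import load_dataset

ds = load_dataset("json", data_files="/tmp/Daring-Anteater/train.jsonl", split="train")
splits = ds.train_test_split(test_size=0.05, seed=42)
train_dataset = LazySupervisedDataset(list(splits["train"]), tokenizer=tokenizer)
eval_dataset = LazySupervisedDataset(list(splits["test"]), tokenizer=tokenizer)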
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 96 to 99, the
notebook currently reads the entire JSONL into memory via json.loads on each
line; replace this with a streaming approach or use the datasets library. Either
(a) iterate the file and yield/process lines lazily (e.g., create generators or
a LazySupervisedDataset that reads from the file path rather than a pre-built
list) so you never materialize data_json, or (b) call
datasets.load_dataset("json", data_files=path, split="train",
streaming=True/with proper train/validation splits) and pass the resulting
dataset (or its iterator) into LazySupervisedDataset to avoid loading the full
file into memory. Ensure subsequent slicing/splitting is done via
streaming-aware methods (e.g., dataset.train_test_split or manual streaming
partition) rather than list indexing.
"@dataclass\n", | ||
"class TrainingArguments(transformers.TrainingArguments):\n", | ||
" cache_dir: str | None = field(default=None)\n", | ||
" model_max_length: int = field(\n", | ||
" default=4096,\n", | ||
" metadata={\n", | ||
" \"help\": (\n", | ||
" \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n", | ||
" )\n", | ||
" },\n", | ||
" )\n", | ||
" dataloader_drop_last: bool = field(default=True)\n", | ||
" bf16: bool = field(default=True)\n", | ||
"\n", |
🛠️ Refactor suggestion
Don’t subclass TrainingArguments for minor toggles.
Subclassing a dataclass here adds maintenance risk for little gain. Pass bf16 and dataloader_drop_last directly.
Apply:
-@dataclass
-class TrainingArguments(transformers.TrainingArguments):
- dataloader_drop_last: bool = field(default=True)
- bf16: bool = field(default=True)
+TrainingArguments = transformers.TrainingArguments
And below (Line 108):
training_args = TrainingArguments(
output_dir="/tmp/eagle_bf16",
num_train_epochs=4,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
+ bf16=True,
+ dataloader_drop_last=True,
)
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 102 to 106 (and the
instantiation at line 108), a custom @dataclass subclass of
transformers.TrainingArguments is being used just to toggle dataloader_drop_last
and bf16; remove this subclass entirely and stop maintaining a custom dataclass.
Replace uses of the custom class with the standard
transformers.TrainingArguments and pass dataloader_drop_last=True and bf16=True
as keyword arguments when creating the TrainingArguments instance (adjust the
instantiation at line 108 accordingly).
"trtllm_serve_script = f\"\"\"trtllm-serve {base_model} \\\\\n", | ||
" --host 0.0.0.0 \\\\\n", | ||
" --port 8000 \\\\\n", | ||
" --backend pytorch \\\\\n", | ||
" --max_batch_size 32 \\\\\n", | ||
" --max_num_tokens 8192 \\\\\n", | ||
" --max_seq_len 8192 \\\\\n", | ||
" --extra_llm_api_options /tmp/extra-llm-api-config.yml\n", | ||
"\"\"\"\n", |
🛠️ Refactor suggestion
Bound server seq/token limits to model/tokenizer.
Hard‑coding 8192 may conflict with the base model’s limit. Use dynamic values.
Apply:
-trtllm_serve_script = f"""trtllm-serve {base_model} \\
+trtllm_serve_script = f"""trtllm-serve {base_model} \\
--host 0.0.0.0 \\
--port 8000 \\
--backend pytorch \\
--max_batch_size 32 \\
- --max_num_tokens 8192 \\
- --max_seq_len 8192 \\
+ --max_num_tokens {tokenizer.model_max_length} \\
+ --max_seq_len {tokenizer.model_max_length} \\
--extra_llm_api_options /tmp/extra-llm-api-config.yml
"""
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 173 to 181, the
script hardcodes --max_num_tokens and --max_seq_len to 8192 which can exceed the
base model/tokenizer limits; change it to compute safe limits at runtime by
loading the model or tokenizer (e.g., tokenizer.model_max_length or
model.config.max_position_embeddings), then set max_num_tokens and max_seq_len
to the smaller of that model limit and your desired target (with a sensible
fallback value if the config is missing). Ensure the trtllm_serve_script string
injects these computed integers rather than the literal 8192 so the server flags
respect the actual model/tokenizer capacity.
"import subprocess\n", | ||
"import threading\n", | ||
"\n", | ||
"# Generate a unique container name so we can stop/remove it later\n", | ||
"container_name = \"trtllm_serve_spec\"\n", | ||
"\n", | ||
"docker_cmd = [\n", | ||
" \"docker\",\n", | ||
" \"run\",\n", | ||
" \"--rm\",\n", | ||
" \"--net\",\n", | ||
" \"host\",\n", | ||
" \"--shm-size=2g\",\n", | ||
" \"--ulimit\",\n", | ||
" \"memlock=-1\",\n", | ||
" \"--ulimit\",\n", | ||
" \"stack=67108864\",\n", | ||
" \"--gpus\",\n", | ||
" \"all\",\n", | ||
" \"-v\",\n", | ||
" \"/tmp:/tmp\",\n", | ||
" \"--name\",\n", | ||
" container_name,\n", | ||
" \"nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2\",\n", | ||
" \"bash\",\n", | ||
" \"-c\",\n", | ||
" \"bash /tmp/trtllm_serve.sh\",\n", | ||
"]\n", |
🛠️ Refactor suggestion
Mount HF cache (+token) and increase shared memory for TRT‑LLM.
Without the HF cache/token in the container, model download can fail; 2g shm is often too small.
Apply:
- "--shm-size=2g",
+ "--shm-size=32g",
@@
- "-v",
- "/tmp:/tmp",
+ "-v", "/tmp:/tmp",
+ "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+ "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+ "-e", "HF_HUB_ENABLE_HF_TRANSFER=1",
+ # optionally pass a token if needed:
+ # "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",
Also consider pre‑removing existing containers:
-# Generate a unique container name so we can stop/remove it later
-container_name = "trtllm_serve_spec"
+# Use a deterministic name and ensure it is not left over from prior runs
+container_name = "trtllm_serve_spec"
+subprocess.call(["docker","rm","-f",container_name])
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 220–247, the Docker
run command lacks mounting the Hugging Face cache/token and uses too small
shared memory; update the docker_cmd to: add volume mounts mapping the host HF
cache and token into the container (e.g. host ~/.cache/huggingface -> container
/root/.cache/huggingface and host ~/.huggingface or token file -> container
/root/.huggingface or appropriate path) so model downloads and auth work inside
the container, increase --shm-size from "2g" to a larger value (e.g. "8g" or
"16g") to avoid OOM on TRT‑LLM, and add a pre-run step to remove any existing
container with the same name (docker rm -f <container_name>) before starting the
new container.
"import json\n", | ||
"import requests\n", | ||
"\n", | ||
"from modelopt.torch.export import export_hf_checkpoint\n", | ||
"payload = {\n", | ||
" \"model\": base_model,\n", | ||
" \"messages\": [\n", | ||
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", | ||
" {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n", | ||
" ],\n", | ||
" \"max_tokens\": 512,\n", | ||
" \"temperature\": 0,\n", | ||
" \"chat_template\": tokenizer.chat_template,\n", | ||
"}\n", | ||
"headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n", | ||
"\n", | ||
"# Move meta tensor back to device before exporting.\n", | ||
"remove_hook_from_module(model, recurse=True)\n", | ||
"response = requests.post(\n", | ||
" \"http://localhost:8000/v1/chat/completions\", headers=headers, data=json.dumps(payload)\n", | ||
")\n", | ||
"output = response.json()\n", | ||
"\n", | ||
"export_hf_checkpoint(\n", | ||
" model,\n", | ||
" export_dir=\"/tmp/hf_ckpt\",\n", | ||
")" | ||
"print(output)" | ||
] |
🛠️ Refactor suggestion
Harden the request and drop non‑standard fields.
OpenAI Chat Completions may not accept chat_template; add error handling.
Apply:
- "chat_template": tokenizer.chat_template,
}
headers = {"Content-Type": "application/json", "Accept": "application/json"}
-response = requests.post(
- "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+resp = requests.post(
+ "http://localhost:8000/v1/chat/completions",
+ headers=headers,
+ data=json.dumps(payload),
+ timeout=60,
+)
+resp.raise_for_status()
+print(resp.json())
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 282 to 303, the POST
payload includes a non-standard "chat_template" field and lacks proper error
handling; remove "chat_template" from the payload, send JSON using requests'
json= parameter (not data=json.dumps), check response.status_code and raise or
print a helpful error when it's not 200, wrap the request in a try/except to
catch requests.exceptions (timeout/connection errors) and set a reasonable
timeout, and ensure headers remain Content-Type: application/json and Accept:
application/json.
"import json\n", | ||
"import requests\n", | ||
"\n", | ||
"payload = {\n", | ||
" \"model\": base_model,\n", | ||
" \"messages\": [\n", | ||
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", | ||
" {\"role\": \"user\", \"content\": \"Tell me about speculative decoding.\"},\n", | ||
" ],\n", | ||
" \"max_tokens\": 512,\n", | ||
" \"temperature\": 0,\n", | ||
"}\n", | ||
"headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n", | ||
"\n", | ||
"#Send request to the SGLang server\n", | ||
"response = requests.post(\n", | ||
" \"http://localhost:30000/v1/chat/completions\", headers=headers, data=json.dumps(payload)\n", | ||
")\n", | ||
"output = response.json()\n", | ||
"\n", | ||
"print(output['choices'][0]['message']['content'])" | ||
] |
🛠️ Refactor suggestion
Add error handling on the SGLang request.
Apply:
-response = requests.post(
+response = requests.post(
"http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
)
-output = response.json()
-
-print(output['choices'][0]['message']['content'])
+response.raise_for_status()
+output = response.json()
+print(output.get("choices", [{}])[0].get("message", {}).get("content", output))
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 423 to 444, the code
sends the SGLang request without any error handling; wrap the requests.post call
in a try/except that catches requests.exceptions.RequestException (use a
reasonable timeout), check response.raise_for_status() or response.ok and
log/raise a descriptive error including response.status_code and response.text
if the call failed, then safely parse JSON inside a try/except for
ValueError/JSONDecodeError and handle missing keys before printing (provide a
clear fallback or error message).
Actionable comments posted: 0
♻️ Duplicate comments (10)
examples/speculative_decoding/example.ipynb (10)
17-19: Replace git clone with huggingface_hub to avoid git-lfs pitfalls. Use snapshot_download for reliability and caching.

-%%sh
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from huggingface_hub import snapshot_download
+dataset_dir = "/tmp/Daring-Anteater"
+snapshot_download(
+    repo_id="nvidia/Daring-Anteater",
+    repo_type="dataset",
+    local_dir=dataset_dir,
+    local_dir_use_symlinks=False,
+)
66-73: Align tokenizer limits to model; set padding side; avoid ChatML fallback. Hardcoding 1024 may truncate; ChatML fallback can mismatch Llama-3.2.

-tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)
-tokenizer.pad_token_id = tokenizer.eos_token_id
-if tokenizer.chat_template is None:
-    tokenizer.chat_template = (
-        "{%- for message in messages %}"
-        "{{- '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
-        "{%- endfor %}"
-    )
+model_max_len = getattr(getattr(model, "config", None), "max_position_embeddings", None) or 2048
+tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=model_max_len)
+tokenizer.padding_side = "right"
+tokenizer.pad_token_id = tokenizer.eos_token_id
+# Avoid overriding tokenizer.chat_template with a generic template.
96-99: Stream dataset or use datasets library; avoid loading entire JSONL into memory.

-with open("/tmp/Daring-Anteater/train.jsonl") as f:
-    data_json = [json.loads(line) for line in f]
-train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
-eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
+from datasets import load_dataset
+ds = load_dataset("json", data_files={"/all": "/tmp/Daring-Anteater/train.jsonl"})["/all"]
+splits = ds.train_test_split(test_size=0.05, seed=42)
+train_dataset = LazySupervisedDataset(splits["train"], tokenizer=tokenizer)
+eval_dataset = LazySupervisedDataset(splits["test"], tokenizer=tokenizer)
102-113: Don't subclass TrainingArguments for minor toggles; pass flags directly.

-from dataclasses import dataclass, field
-@dataclass
-class TrainingArguments(transformers.TrainingArguments):
-    dataloader_drop_last: bool = field(default=True)
-    bf16: bool = field(default=True)
-
-
-training_args = TrainingArguments(
+TrainingArguments = transformers.TrainingArguments
+training_args = TrainingArguments(
     output_dir="/tmp/eagle_bf16",
     num_train_epochs=4,
     per_device_train_batch_size=1,
     per_device_eval_batch_size=1,
+    bf16=True,
+    dataloader_drop_last=True,
 )
122-122: Avoid private API Trainer._move_model_to_device(). Let Trainer handle placement or call model.to(...) explicitly.

-trainer._move_model_to_device(model, trainer.args.device)
+# Let Trainer manage device placement.
173-181: Bound max tokens/seq to tokenizer/model; don't hardcode 8192.

-    --max_num_tokens 8192 \
-    --max_seq_len 8192 \
+    --max_num_tokens {tokenizer.model_max_length} \
+    --max_seq_len {tokenizer.model_max_length} \
220-247: Harden TRT-LLM docker run: shm, HF cache/token, cleanup.

-# Generate a unique container name so we can stop/remove it later
-container_name = "trtllm_serve_spec"
+container_name = "trtllm_serve_spec"
+subprocess.call(["docker","rm","-f",container_name])
@@
-    "--shm-size=2g",
+    "--shm-size=32g",
@@
-    "-v",
-    "/tmp:/tmp",
+    "-v", "/tmp:/tmp",
+    "-v", f"{os.path.expanduser('~')}/.cache/huggingface:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",

What is the latest stable NGC tag for nvcr.io/nvidia/tensorrt-llm/release, and does trtllm-serve accept --backend pytorch with speculative_config?
282-303: Remove non-standard chat_template and add robust HTTP error handling.

-    "temperature": 0,
-    "chat_template": tokenizer.chat_template,
+    "temperature": 0,
@@
-response = requests.post(
-    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+resp = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    headers=headers,
+    json=payload,
+    timeout=60,
+)
+resp.raise_for_status()
+print(resp.json())
423-444: Add error handling to SGLang request.

-response = requests.post(
-    "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+resp = requests.post(
+    "http://localhost:30000/v1/chat/completions",
+    headers=headers,
+    json=payload,
+    timeout=60,
+)
+resp.raise_for_status()
+print(resp.json().get("choices", [{}])[0].get("message", {}).get("content"))
39-60: Do not mutate EAGLE3_DEFAULT_CFG in place; deep-copy before updates. Global mutation risks bleed-through across runs/imports.
Note: I'm aware of the prior learning that ALGO_TO_CONFIG yields deep copies in tests; here we're using EAGLE3_DEFAULT_CFG directly, so copy locally to be safe.

-from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
+from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
+import copy
@@
-# Read Default Config for EAGLE3
-config = EAGLE3_DEFAULT_CFG["config"]
+# Read Default Config for EAGLE3 (local copy)
+config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])
🧹 Nitpick comments (2)
examples/speculative_decoding/example.ipynb (2)
375-387: Pass HF token into SGLang container and pre-remove any stale container.

-container_name = "sglang_serve_spec"
+container_name = "sglang_serve_spec"
+subprocess.call(["docker","rm","-f",container_name])
@@
-    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",
93-94: Import OK — eagle_utils present at examples/speculative_decoding/eagle_utils.py. It defines LazySupervisedDataset and DataCollatorWithPadding, so "from eagle_utils import ..." works when example.ipynb is run from that directory. If the notebook will be executed from other CWDs or the module reused elsewhere, either add examples/speculative_decoding to PYTHONPATH, convert the folder to a package and use a relative import, or install the module.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/speculative_decoding/example.ipynb (12 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/example.ipynb
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/example.ipynb
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/example.ipynb
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (5)
examples/speculative_decoding/example.ipynb (5)
147-153: LGTM: HF Unified checkpoint export is correctly invoked.
318-320: LGTM: Cleanup step present to remove TRT-LLM container.
44-47: Access to meta-llama models often requires an HF token. If the base model isn't in the mounted HF cache, from_pretrained will fail in containers. Ensure HF_TOKEN is set and the cache is mounted (see docker comments).
336-345: Confirmed: SGLang EAGLE3 CLI flags are correct.
--speculative-algorithm EAGLE3, --speculative-draft-model-path, and --speculative-num-draft-tokens are valid/current for sglang.launch_server; optional tuning flags include --speculative-num-steps and --speculative-eagle-topk.
114-121: Incorrect — EAGLE conversion/config already handles freezing the base model. The conversion uses eagle_freeze_base_model (default True) and passes it into the converted model, which sets base params' requires_grad=False; the EAGLE head lives under model.eagle_module (e.g. eagle_module.eagle_lm_head, eagle_module.fc, eagle_module.decoder.layers.*). See modelopt/torch/speculative/config.py, modelopt/torch/speculative/eagle/conversion.py, modelopt/torch/speculative/plugins/megatron_eagle.py and modelopt/torch/speculative/plugins/transformers.py for the freeze logic.
Likely an incorrect or invalid review comment.
Updated example.ipynb with sglang steps: updated haoguo's demo notebook with sglang instructions, fixed --model-path to --model, commit with sign.
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: Jamie Li <[email protected]>

Force-pushed f6b85a9 to f85794e (Compare)
Codecov Report
✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:
@@            Coverage Diff            @@
##             main     #316    +/-   ##
========================================
  Coverage   73.87%   73.88%
========================================
  Files         172      172
  Lines       17415    17444     +29
========================================
+ Hits        12865    12888     +23
- Misses       4550     4556      +6

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
♻️ Duplicate comments (9)
examples/speculative_decoding/example.ipynb (9)
17-18: Avoid git-LFS headaches: use huggingface_hub instead of git clone. Replace the shell clone with snapshot_download for reliability and caching.

-!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater
+from huggingface_hub import snapshot_download
+snapshot_download(
+    repo_id="nvidia/Daring-Anteater",
+    repo_type="dataset",
+    local_dir="/tmp/Daring-Anteater",
+    local_dir_use_symlinks=False,
+)
64-72: Align tokenizer limits to the model and set padding side explicitly.

-# Prepare Tokenizer
-tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=1024)
-tokenizer.pad_token_id = tokenizer.eos_token_id
+# Prepare Tokenizer
+cfg = transformers.AutoConfig.from_pretrained(base_model)
+max_len = int(getattr(cfg, "max_position_embeddings", 2048))
+tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=max_len)
+tokenizer.pad_token_id = tokenizer.eos_token_id
+tokenizer.padding_side = "right"
 if tokenizer.chat_template is None:
95-98: Avoid loading the entire JSONL into memory; stream or use datasets.

-with open("/tmp/Daring-Anteater/train.jsonl") as f:
-    data_json = [json.loads(line) for line in f]
-train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
-eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
+def stream_jsonl(path):
+    with open(path) as f:
+        for line in f:
+            yield json.loads(line)
+all_data = list(stream_jsonl("/tmp/Daring-Anteater/train.jsonl"))
+split = int(len(all_data) * 0.95)
+train_dataset = LazySupervisedDataset(all_data[:split], tokenizer=tokenizer)
+eval_dataset = LazySupervisedDataset(all_data[split:], tokenizer=tokenizer)

(Or switch to datasets.load_dataset("json") with train_test_split.)
101-112: Don't subclass TrainingArguments for minor toggles.

-from dataclasses import dataclass, field
+from dataclasses import dataclass, field  # (no longer needed if not subclassing)
@@
-@dataclass
-class TrainingArguments(transformers.TrainingArguments):
-    dataloader_drop_last: bool = field(default=True)
-    bf16: bool = field(default=True)
+TrainingArguments = transformers.TrainingArguments
@@
 training_args = TrainingArguments(
     output_dir="/tmp/eagle_bf16",
     num_train_epochs=4,
     per_device_train_batch_size=1,
     per_device_eval_batch_size=1,
+    bf16=True,
+    dataloader_drop_last=True,
 )
121-121: Remove private API call Trainer._move_model_to_device(). This is unstable across HF releases; let Trainer handle device placement.

-trainer._move_model_to_device(model, trainer.args.device)
+# Let Trainer manage device placement
172-180: Bind TRT-LLM token/sequence limits to the model.

 trtllm_serve_script = f"""trtllm-serve {base_model} \
@@
-    --max_num_tokens 8192 \
-    --max_seq_len 8192 \
+    --max_num_tokens {tokenizer.model_max_length} \
+    --max_seq_len {tokenizer.model_max_length} \
219-246: Container ergonomics: mount HF cache/token, bump shm, and pre-remove old container.

 import subprocess
 import threading
+import os
@@
-# Generate a unique container name so we can stop/remove it later
+# Generate a unique container name so we can stop/remove it later
 container_name = "trtllm_serve_spec"
+home_dir = os.path.expanduser("~")
+hf_cache_dir = os.path.join(home_dir, ".cache", "huggingface")
+os.makedirs(hf_cache_dir, exist_ok=True)
+subprocess.call(["docker", "rm", "-f", container_name])
 docker_cmd = [
     "docker",
     "run",
     "--rm",
     "--net",
     "host",
-    "--shm-size=2g",
+    "--shm-size=32g",
     "--ulimit",
     "memlock=-1",
     "--ulimit",
     "stack=67108864",
+    "--ipc=host",
     "--gpus",
     "all",
-    "-v",
-    "/tmp:/tmp",
+    "-v", "/tmp:/tmp",
+    "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    "-e", "HF_HUB_ENABLE_HF_TRANSFER=1",
+    # optionally: "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",
     "--name",
     container_name,
     "nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2",
     "bash",
     "-c",
     "bash /tmp/trtllm_serve.sh",
 ]
281-301: Harden the TRT-LLM request; remove non-standard fields.

 payload = {
     "model": base_model,
@@
-    "temperature": 0,
-    "chat_template": tokenizer.chat_template,
+    "temperature": 0,
 }
 headers = {"Content-Type": "application/json", "Accept": "application/json"}
-response = requests.post(
-    "http://localhost:8000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+resp = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    headers=headers,
+    json=payload,
+    timeout=60,
+)
+resp.raise_for_status()
+print(resp.json())
421-442: Add basic error handling on the SGLang request.

-#Send request to the SGLang server
-response = requests.post(
-    "http://localhost:30000/v1/chat/completions", headers=headers, data=json.dumps(payload)
-)
-output = response.json()
-
-print(output)
+# Send request to the SGLang server
+resp = requests.post(
+    "http://localhost:30000/v1/chat/completions",
+    headers=headers,
+    json=payload,
+    timeout=60,
+)
+resp.raise_for_status()
+print(resp.json())
🧹 Nitpick comments (5)
examples/speculative_decoding/example.ipynb (5)
34-39: Don't mutate imported defaults; work on a copy of the EAGLE3 config. Mutating EAGLE3_DEFAULT_CFG in place can leak across cells/runs. Use a local deepcopy.
Using your retrieved learning: tests may not need extra deepcopy, but in a long-lived notebook kernel avoiding global mutation is safer.

 import transformers
+import copy
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
@@
-# Read Default Config for EAGLE3
-config = EAGLE3_DEFAULT_CFG["config"]
+# Read Default Config for EAGLE3 (local copy to avoid global mutation)
+config = copy.deepcopy(EAGLE3_DEFAULT_CFG["config"])

Also applies to: 48-59
43-46: Note: base model may require an HF token. Llama 3.2 repos often need auth. Add a short note or check for HF_TOKEN/env login to reduce friction.

 # Load original HF model
 base_model = "meta-llama/Llama-3.2-1B"
+import os
+if not os.environ.get("HF_TOKEN"):
+    print("Hint: this model may require Hugging Face auth (HF_TOKEN or `huggingface-cli login`).")
 model = transformers.AutoModelForCausalLM.from_pretrained(
     base_model, torch_dtype="auto", device_map="cuda"
 )
362-405: SGLang container: ensure HF cache env, optional token, and cleanup.

 import subprocess
 import threading
 import os
@@
 container_name = "sglang_serve_spec"
@@
-docker_cmd = [
+subprocess.call(["docker", "rm", "-f", container_name])
+docker_cmd = [
     "docker",
     "run",
     "--rm",
     "--net",
     "host",
     "--shm-size=32g",
     "--gpus",
     "all",
     "-v", f"{hf_cache_dir}:/root/.cache/huggingface",
     "-v", "/tmp:/tmp",
     "--ipc=host",
+    "-e", "HUGGINGFACE_HUB_CACHE=/root/.cache/huggingface",
+    "-e", "HF_HUB_ENABLE_HF_TRANSFER=1",
+    # optionally: "-e", f"HF_TOKEN={os.environ.get('HF_TOKEN','')}",
     "--name",
     container_name,
-    "lmsysorg/sglang:latest",
+    "lmsysorg/sglang:latest",
     "bash",
     "-c",
     "bash /tmp/sglang_serve.sh",
 ]

Also consider pinning the image tag instead of "latest" for reproducibility.
271-273: Optional: add readiness polling instead of manual log watching. Example snippet to insert before sending requests:

import time, requests

def wait_ready(url, timeout=120):
    t0 = time.time()
    while time.time() - t0 < timeout:
        try:
            r = requests.get(url, timeout=2)
            if r.status_code in (200, 404):
                return True
        except Exception:
            pass
        time.sleep(2)
    raise TimeoutError(f"Service not ready: {url}")

wait_ready("http://localhost:8000")  # TRT-LLM
# wait_ready("http://localhost:30000")  # SGLang

Also applies to: 411-413
113-120: Verify only the EAGLE head is trainable. Sanity-check that base model params are frozen post-convert; prevents unintended fine-tuning.

# Diagnostic check (optional cell)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(f"Trainable params: {len(trainable)}")
assert all("eagle" in n.lower() for n in trainable), "Unexpected trainable params outside EAGLE head."
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/speculative_decoding/example.ipynb (12 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/example.ipynb
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/example.ipynb
📚 Learning: 2025-09-05T19:10:36.393Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.393Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/example.ipynb
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-docs
- GitHub Check: code-quality
"# Convert Model for eagle speculative decoding\n", | ||
"mtsp.convert(model, [(\"eagle\", config)])\n", |
💡 Verification agent
🧩 Analysis chain
Verify algorithm naming is consistent across ModelOpt, TRT‑LLM, and SGLang.
mtsp.convert uses "eagle" while TRT‑LLM yaml says "Eagle" and SGLang flag uses "EAGLE3". Confirm accepted identifiers/case to avoid runtime mismatches.
Also applies to: 189-191, 339-342
🌐 Web query:
Check:
1) modelopt.torch.speculative mtsp.convert adapter names (is "eagle" correct for EAGLE3?).
2) TRT-LLM extra_llm_api_config 'speculative_config.decoding_type' accepted values and case ("Eagle" vs "EAGLE3").
3) SGLang --speculative-algorithm accepted values for EAGLE3 and required draft model flags.
💡 Result:
Answers — concise:
- modelopt.torch.speculative mtsp.convert adapter name
- Use "eagle" (lowercase) as the mode string for mtsp.convert. [1][2]
- TRT-LLM extra_llm_api_config / speculative_config.decoding_type (case/values)
- TRT-LLM uses lowercase mode names; the speculative/decoding mode values include (examples) "eagle", "medusa", "lookahead", "redrafter", "top_k", "top_p", "top_k_top_p", etc. Use "eagle" (lowercase) for EAGLE-style speculative decoding. (Note: EAGLE‑2 support is limited/not supported in some Triton/TensorRT-LLM deployments per docs.) [3][4]
- SGLang --speculative-algorithm accepted values and draft-model flags
- SGLang accepts "EAGLE3" (capitalization as shown) for the flag --speculative-algorithm (it documents EAGLE3 explicitly). SGLang requires a draft model path (speculative_draft_model_path / --speculative-draft-model-path) when using EAGLE/EAGLE3; other EAGLE-related flags include speculative_num_steps, speculative_eagle_topk, speculative_num_draft_tokens, etc. [5][6]
Sources
- [1] NVIDIA TensorRT-Model-Optimizer speculative_decoding API (mtsp.convert modes include "eagle"). (nvidia.github.io)
- [2] TensorRT-Model-Optimizer examples / eagle module docs. (github.com)
- [3] TensorRT-LLM / trtllm-build and model_config docs (speculative_decoding_mode / decoding_mode values; "eagle" lowercase). (nvidia.github.io)
- [4] TensorRT-LLM speculative-decoding tutorial / notes (EAGLE details / EAGLE-2 support notes). (docs.nvidia.com)
- [5] SGLang docs: backend/advanced pages show --speculative-algorithm EAGLE3 and required speculative_draft_model_path. (docs.sglang.ai)
- [6] Hugging Face SGLang EAGLE3 draft-model examples (usage showing --speculative-algorithm EAGLE3 + draft model). (huggingface.co)
If you want, I can paste exact example command lines from the docs for each case.
Citations:
- 1: https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.speculative.speculative_decoding.html?utm_source=openai
- 2: https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/examples/speculative_decoding/README.md?utm_source=openai
- 3: https://nvidia.github.io/TensorRT-LLM/commands/trtllm-build.html?utm_source=openai
- 4: https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/tutorials/Feature_Guide/Speculative_Decoding/TRT-LLM/README.html?utm_source=openai
- 5: https://docs.sglang.ai/backend/speculative_decoding.html?utm_source=openai
- 6: https://huggingface.co/lmsys/sglang-EAGLE3-Llama-4-Maverick-17B-128E-Instruct-v1?utm_source=openai
Align speculative-decoding names: use "eagle" for ModelOpt/TRT‑LLM; keep "EAGLE3" for SGLang and provide a draft model.
- examples/speculative_decoding/example.ipynb (61–62; also 189–191, 339–342): mtsp.convert already uses "eagle" — keep as-is.
- TRT‑LLM YAML/config: change any "Eagle" → "eagle" (search for speculative_config.decoding_type / decoding_type / speculative_decoding_mode) to match TRT‑LLM/ModelOpt accepted lowercase values.
- SGLang: continue using "--speculative-algorithm EAGLE3" and ensure a --speculative-draft-model-path / speculative_draft_model_path is supplied.
Add a one-line mapping in the example/README explaining that SGLang uses "EAGLE3" while ModelOpt/TRT‑LLM use "eagle" to avoid confusion.
🤖 Prompt for AI Agents
In examples/speculative_decoding/example.ipynb around lines 61-62 (and note
similar usages at 189-191 and 339-342), the notebook already uses
mtsp.convert(model, [("eagle", config)]) which is correct for ModelOpt/TRT‑LLM;
ensure any TRT‑LLM YAML/config keys that reference the speculative decoding type
use lowercase "eagle" (search for speculative_config.decoding_type,
decoding_type, speculative_decoding_mode and change "Eagle" → "eagle"); for
SGLang keep the CLI flag value as "EAGLE3" and ensure the SGLang invocation
provides a draft model via --speculative-draft-model-path (or
speculative_draft_model_path in configs); finally add a one-line note in the
example README mapping that SGLang uses "EAGLE3" while ModelOpt/TRT‑LLM use
"eagle" to avoid confusion.
What does this PR do?
**Type of change:** Documentation
Overview: Updated example.ipynb for speculative decoding to demonstrate deployment with SGLang and TRT-LLM.
Usage
# Add a code snippet demonstrating how to use this
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Changes
Breaking Changes
Documentation
Chores